"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


Clusters the wild mice dataset with Infomap, both over static (half-)weekly time
windows and over the corresponding multilayer network.

Plots Figs. S3 and S4.


"""


import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))


import numpy as np
import pandas as pd
from TemporalNetwork import ContTempNetwork



import matplotlib.pyplot as plt
plt.style.use('alex_paper')
from infomap import Infomap
from TemporalStability import Partition

# --- configuration -----------------------------------------------------------
# NOTE(review): ncpu and num_repeat are not used in the cells below; kept for
# compatibility with sibling scripts that may rely on them.
ncpu = 4

num_repeat = 50

datadir = '../paper_data/mice_data_march_april2017/'

# Infomap inputs/outputs (pajek slices, multilayer file, .clu files)
savedir = '../paper_data/mice_data_march_april2017/infomap_results'
os.makedirs(savedir, exist_ok=True)

# Figures are written here by several cells below; create the directory up
# front so that plt.savefig does not fail on a fresh checkout.
figdir = '../figures/micenet_march_april2017/infomap_results'
os.makedirs(figdir, exist_ok=True)

# NOTE(review): unused below; '2007' may be a typo for 2017 — confirm.
file_prefix = 'micenet2007'

net_file = os.path.join(datadir, 'micenet_2017_02_28_to_2017_05_01')


# Unconditional stop: running this file top to bottom halts here, so nothing
# below executes. Presumably the cells after this point (delimited by `#%%`)
# are meant to be run one at a time from an IDE.
raise Exception



#%%% network file

# Load the continuous-time temporal network, restricted to the listed
# attributes (only these are read from `net_file`).
net = ContTempNetwork.load(net_file, attributes_list=['node_to_label_dict',
                      'events_table',
                      'times',
                      'time_grid',
                      'num_nodes',
                      'start_date',
                      'male_array',
                      'female_array'])


    



#%% static adjacency matrices over consecutive half-week windows

# window length in seconds (half a week)
slice_length = (7 * 24 * 60 * 60) / 2

t0 = 0
tend = net.times[-1]

# Window boundaries t0, t0 + L, t0 + 2L, ...; the last boundary is the first
# one strictly greater than tend.
t = t0
time_slices = [t0]
while t <= tend:
    t = t + slice_length
    time_slices.append(t)

# consecutive boundaries form the (start, stop) pair of each window
t_starts, t_stops = time_slices[:-1], time_slices[1:]

# One static adjacency matrix per window; windows whose stored entries are
# all zero are skipped, so list positions need not match window positions.
adjacencies = []
for ts, te in zip(t_starts, t_stops):
    print(ts, te)
    A = net.compute_static_adjacency_matrix(start_time=ts, end_time=te)
    if np.any(A.data != 0):
        adjacencies.append(A)

#%% one pajek file per non-empty static window

import networkx as nx

# NetworkX 3.0 removed `from_scipy_sparse_matrix`; `from_scipy_sparse_array`
# (NetworkX >= 2.7) is its replacement. Fall back for older installs.
_graph_from_sparse = getattr(nx, 'from_scipy_sparse_array', None)
if _graph_from_sparse is None:
    _graph_from_sparse = nx.convert_matrix.from_scipy_sparse_matrix

graphs = []
for i, A in enumerate(adjacencies):
    print(i)

    # nodes are 0..n-1 (matrix indices), edge weights under 'weight'
    G = _graph_from_sparse(A)
    graphs.append(G)

    nx.write_pajek(G, path=os.path.join(savedir, f'halfweekly_slice_{i:03d}.net'))



#%% write multilayer pajek
# Infomap multilayer format: a *Vertices section listing every physical node
# once (1-based id and quoted label), then an *Intra section with one line per
# intra-layer link: "layer_id node_id node_id weight", all ids 1-based.
# Layer l+1 corresponds to graphs[l].

with open(os.path.join(savedir, 'halfweekly_multilayer.net'), 'w') as fopen:
    fopen.write(f'*Vertices {net.num_nodes}\n')
    for n in range(net.num_nodes):
        fopen.write(f'{n+1} "{n}"\n')

    fopen.write('*Intra\n')
    for l, G in enumerate(graphs):
        for s, t, w in G.edges(data='weight'):
            # The weight must be a bare number: no trailing quote, otherwise
            # the link line is malformed for Infomap's parser.
            fopen.write(f'{l+1} {s+1} {t+1} {w}\n')


            
#%% cluster each static window with Infomap

# Which set of slice files to cluster. The 'halfweekly' slices are written
# earlier in this file; 'weekly' slice files are not produced here.
window = 'halfweekly'  # or 'weekly'

_NUM_WINDOWS = {'weekly': 9, 'halfweekly': 18}
if window not in _NUM_WINDOWS:
    # fail loudly instead of leaving num_windows undefined
    raise ValueError(f'unknown window {window!r}; expected one of {sorted(_NUM_WINDOWS)}')
num_windows = _NUM_WINDOWS[window]


def _infomap_partition(infomap, depth_level):
    """Build a Partition from Infomap's modules at ``depth_level``.

    Node ids read from the pajek file are shifted by -1 so they are 0-based,
    like the node indices of the graphs built from the adjacency matrices.
    """
    modules = infomap.get_modules(depth_level=depth_level)
    return Partition(num_nodes=infomap.num_nodes,
                     node_to_cluster_dict={n - 1: c for n, c in modules.items()})


p_coarse = []  # per window: partition at depth_level=1 (coarsest)
p_fine = []    # per window: partition at depth_level=-1 (finest)
p_lev = []     # per window: {level: partition} for every Infomap level

for i in range(num_windows):
    im = Infomap()
    im.read_file(os.path.join(savedir, f'{window}_slice_{i:03d}.net'))
    im.run()

    p_coarse.append(_infomap_partition(im, 1))
    p_fine.append(_infomap_partition(im, -1))
    p_lev.append({l: _infomap_partition(im, l) for l in range(im.num_levels)})
    

#%% remove inactive mice (no link in the window) from every cluster


def _restrict(clusters, keep):
    """Intersect each cluster with ``keep`` and drop clusters that end up empty."""
    trimmed = (cluster & keep for cluster in clusters)
    return [cluster for cluster in trimmed if len(cluster) > 0]


c_list_coarse, c_list_fine, c_all_levels, active_mice_list = [], [], [], []

for i in range(num_windows):
    g = graphs[i]

    # a mouse is active in window i if its degree in that window is positive
    degs = dict(g.degree(range(g.number_of_nodes())))
    active_mice = {n for n, k in degs.items() if k > 0}

    c_list_coarse.append(_restrict(p_coarse[i].cluster_list, active_mice))
    c_list_fine.append(_restrict(p_fine[i].cluster_list, active_mice))
    c_all_levels.append({lev: _restrict(part.cluster_list, active_mice)
                         for lev, part in p_lev[i].items()})

    active_mice_list.append(active_mice)
#%% number of (active) clusters per window

num_c_coarse = list(map(len, c_list_coarse))
num_c_fine = list(map(len, c_list_fine))

# number of Infomap levels recorded for each window
num_levels = [len(levels) for levels in c_all_levels]

# per level: cluster count in each window, NaN where the window lacks that level
num_c_all_lev = {
    lev: [len(levels[lev]) if lev in levels else np.nan for levels in c_all_levels]
    for lev in range(max(num_levels))
}
    
#%% number of clusters per window at each Infomap level
plt.rc('font', size=12)

# The x axis is labelled in weeks, but each window is half a week long when
# window == 'halfweekly', so convert window indices to weeks.
# NOTE(review): windows whose adjacency was all zeros were dropped when
# building `adjacencies`; index-to-week conversion assumes none were dropped.
weeks_per_window = 0.5 if window == 'halfweekly' else 1.0

plt.figure()
for l in range(1, max(num_levels)):
    counts = num_c_all_lev[l]
    plt.plot(np.arange(len(counts)) * weeks_per_window, counts, '-o',
             label=f'level {l}')
plt.xlabel('week')
plt.ylabel('num clust per snapshot')
plt.legend()

#%% number of clusters per window at the coarsest and finest levels

# Convert window indices to weeks (windows are half a week long when
# window == 'halfweekly') so the x values match the 'week' label. Each
# series uses its own length for its x values.
weeks_per_window = 0.5 if window == 'halfweekly' else 1.0

plt.figure()
plt.plot(np.arange(len(num_c_coarse)) * weeks_per_window, num_c_coarse, '-o',
         label='coarsest level')
plt.plot(np.arange(len(num_c_fine)) * weeks_per_window, num_c_fine, '-o',
         label='finest level')
plt.xlabel('week')
plt.ylabel('num clust. per snapshot')
plt.legend()

#%% flow stability: number of non-singleton communities per week
# NOTE: pd.read_pickle unpickles arbitrary objects; only load files you produced.
multi_res = pd.read_pickle(os.path.join('../paper_data/mice_data_march_april2017/',
                                        'multi_weeks_partitions.pickle'))


def _num_nontrivial(partition):
    """Number of clusters in ``partition.cluster_list`` with more than one member."""
    return sum(1 for cluster in partition.cluster_list if len(cluster) > 1)


plt.figure()

for tau_w in (1.0, 86400.0):
    res = multi_res[tau_w]
    # (forward count, backward count) for each week
    num_c_per_week = [(_num_nontrivial(p), _num_nontrivial(q))
                      for p, q in zip(res['partitions_forw'], res['partitions_back'])]
    week_idx = np.arange(len(num_c_per_week))
    plt.plot(week_idx, [fw for fw, _ in num_c_per_week], '-o', label=f'forw, tau_w={tau_w}')
    plt.plot(week_idx, [bw for _, bw in num_c_per_week], '-o', label=f'back, tau_w={tau_w}')
plt.xlabel('week')
plt.ylabel('num clust per week (len(c)>1)')
plt.legend()



#%% all together: Infomap levels (half weeks) vs flow stability (weeks)

fig, (ax, ax2) = plt.subplots(2, 1, sharey=True, figsize=(4.1, 7.89))

# top: clusters per half-week window at each Infomap level (index * 0.5 = week)
for l in range(1, max(num_levels)):
    counts = num_c_all_lev[l]
    ax.plot(np.arange(len(counts)) * 0.5, counts, '-x', label=f'level {l}')

# bottom: non-singleton flow-stability clusters per week, forward and backward
for tau_w in (1.0, 86400.0):
    res = multi_res[tau_w]
    num_c_per_week = [
        (sum(len(c) > 1 for c in p.cluster_list),
         sum(len(c) > 1 for c in q.cluster_list))
        for p, q in zip(res['partitions_forw'], res['partitions_back'])
    ]
    week_idx = np.arange(len(num_c_per_week))
    ax2.plot(week_idx, [fw for fw, _ in num_c_per_week], '-o',
             label=fr'forw, $\tau_w$ = {tau_w:.0f} s')
    ax2.plot(week_idx, [bw for _, bw in num_c_per_week], '-o',
             label=fr'back, $\tau_w$ = {tau_w:.0f} s')

for axis_, ylabel_ in ((ax, 'num clust per half week'), (ax2, 'num clust per week')):
    axis_.set_xlabel('week')
    axis_.set_ylabel(ylabel_)

ax.legend(bbox_to_anchor=(0., 0.0, 1., .102), loc='lower left', ncol=4,
          mode="expand", borderaxespad=0., fontsize=9)
ax2.legend(fontsize=10)
ax.set_ylim([0, 32.4])

#%% track communities across windows with MajorTrack
from majortrack import MajorTrack
from copy import deepcopy
import matplotlib.gridspec as gridspec

# window i spans [i, i+1]; times are scaled by 10 for the tracker
_time_windows = [[i, i + 1] for i in range(num_windows)]
time_windows = [[10 * el for el in tw] for tw in _time_windows]


def _make_tracker(clusterings):
    """Create a MajorTrack over ``clusterings`` (history=0) and compute its group match-up."""
    tracker = MajorTrack(
        clusterings=clusterings,
        individuals=active_mice_list,
        history=0,
        timepoints=[tw[0] for tw in time_windows],
    )
    tracker.get_group_matchup('fraction')
    return tracker


mt_coarse = _make_tracker(c_list_coarse)
mt_fine = _make_tracker(c_list_fine)

# The match-up above is computed with history=0; history is raised to 8
# only before the community (dcs / membership / coloring) steps below.
for tracker in (mt_coarse, mt_fine):
    tracker.history = 8

for tracker in (mt_coarse, mt_fine):
    tracker.get_dcs()
    tracker.get_community_group_membership()
    tracker.get_community_membership()
    tracker.get_community_coloring()
    
# Base plotting parameters for the alluvial diagrams.
plot_params = dict(
    cluster_width=2,
    flux_kwargs={'alpha': 0.2, 'lw': 0.0, 'facecolor': 'cluster'},
    cluster_kwargs={'alpha': 1.0, 'lw': 0.0},
    label_kwargs={'fontweight': 'heavy'},
    with_cluster_labels=False,
    cluster_label='group_index',
    cluster_label_margin=(-1.6, 0.1),
    x_axis_offset=0.07,
    redistribute_vertically=1,
    cluster_location='center',
)

# Deep copy of mt_coarse taken before any diagram is drawn.
rawmt = deepcopy(mt_coarse)

# Keyword arguments passed to get_alluvialdiagram: the base parameters plus
# cluster face/edge colors keyed by 'community' and flux colors by 'cluster'.
sankey_plot_params = {
    **plot_params,
    'merged_edgecolor': 'none',  # 'xkcd:gray',
    'merged_linewidth': 1,
    'cluster_facecolor': 'community',
    'cluster_edgecolor': 'community',
    'flux_facecolor': 'cluster',
    'flux_edgecolor': 'cluster',
    'l_size': 9,
}

#%% alluvial diagrams of the tracked coarse and fine partitions, one PDF each

for tracker, granularity in ((mt_coarse, 'coarse'), (mt_fine, 'fine')):
    fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
    tracker.get_alluvialdiagram(ax, **sankey_plot_params)
    plt.savefig(os.path.join(figdir, f'alluv_{window}_history8step_{granularity}_partition.pdf'))


#%% paper figure
# Panel A: clusters per half week at each Infomap level (x in weeks).
# Panel B: non-singleton flow-stability clusters per week (forward/backward).
# Panels C/D: alluvial diagrams of the tracked coarse/fine Infomap partitions.

fig = plt.figure()
gs = gridspec.GridSpec(2, 2)

ax = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0], sharey=ax)
ax3 = fig.add_subplot(gs[0, 1])
ax4 = fig.add_subplot(gs[1, 1])

for l in range(1, max(num_levels)):
    ax.plot(np.arange(len(num_c_all_lev[l])) * 0.5, num_c_all_lev[l], '-x',
            label=f'level {l}')

for tau_w in [1.0, 86400.0]:
    num_c_per_week = [(len([c for c in p.cluster_list if len(c) > 1]),
                       len([c for c in q.cluster_list if len(c) > 1]))
                      for p, q in zip(multi_res[tau_w]['partitions_forw'],
                                      multi_res[tau_w]['partitions_back'])]

    ax2.plot(np.arange(len(num_c_per_week)), [a[0] for a in num_c_per_week], '-o',
             label=fr'forw, $\tau_w$ = {tau_w:.0f} s')
    ax2.plot(np.arange(len(num_c_per_week)), [a[1] for a in num_c_per_week], '-o',
             label=fr'back, $\tau_w$ = {tau_w:.0f} s')

ax.set_xlabel('week')
ax.set_ylabel('num clust per half week')
ax2.set_xlabel('week')
ax2.set_ylabel('num clust per week')

ax.legend(bbox_to_anchor=(0., 0.0, 1., .102), loc='lower left',
          ncol=4, mode="expand", borderaxespad=0., fontsize=9)
ax2.legend(fontsize=10)
ax.set_ylim([0, 32.4])

mt_coarse.get_alluvialdiagram(ax3, **sankey_plot_params)
mt_fine.get_alluvialdiagram(ax4, **sankey_plot_params)

ax3.axis('off')
ax4.axis('off')

# Panel letters. The matplotlib Text property is `usetex` (lowercase);
# `useTex` is not a valid keyword and makes Axes.text raise.
for panel_ax, letter in zip((ax, ax2, ax3, ax4), 'ABCD'):
    panel_ax.text(-0.1, 1.1, rf'\textbf{{{letter}}}',
                  transform=panel_ax.transAxes, usetex=True)

gs.tight_layout(fig)
#%%
plt.savefig(os.path.join(figdir, f'fig_mice_comparison_alluv_{window}.png'), dpi=600)
plt.savefig(os.path.join(figdir, f'fig_mice_comparison_alluv_{window}.pdf'))




#%% halfweekly multilayer network: run Infomap for each relax rate

relax_rates = [0.001, 0.01]

for rr in relax_rates:
    im = Infomap()
    im.read_file(os.path.join(savedir, 'halfweekly_multilayer.net'))

    im.run(
        # two_level=True,
        flow_model='undirected',
        multilayer_relax_by_jsd=True,
        multilayer_relax_limit=1,
        multilayer_relax_rate=rr,
    )

    # Write one .clu file (with state ids) per hierarchy level, into the same
    # results directory as the pajek inputs.
    for level in range(im.num_levels):
        clu_file = f'halfweekly_multilay_relax_rate{rr}_level{level}.clu'
        im.write_clu(os.path.join(savedir, clu_file), states=True, depth_level=level)
        
        
#%% read the multilayer .clu files into node x layer module matrices

rr = 0.01
# The multilayer file has one layer per entry of `graphs` (layer l+1 <-> graphs[l]).
num_layers = len(graphs)

clust_mat = {}
for level in range(1, 6):
    clu_file = f'halfweekly_multilay_relax_rate{rr}_level{level}.clu'

    # rows: nodes, columns: layers; NaN where a node has no state in a layer
    clust_mat[level] = np.full((net.num_nodes, num_layers), np.nan)

    with open(os.path.join(savedir, clu_file), 'r') as fopen:
        for line in fopen:
            # skip header/comment lines and blank lines
            if not line.strip() or line.startswith('#'):
                continue
            state_id, module, flow, node_id, layer_id = line.split()
            # node and layer ids in the file are 1-based
            clust_mat[level][int(node_id) - 1, int(layer_id) - 1] = int(module)
        
#%% plot
level = 1  # hierarchy level of the multilayer modules drawn below
from matplotlib.colors import ListedColormap

# 23 base colors for the module colormap
color_list = [
    "#256676", "#1eefc9", "#16894a", "#bde267", "#7c8a4f", "#aedfca",
    "#1a4fa3", "#63b1f3", "#5c39b4", "#bd9bf4", "#a54984", "#f6248f",
    "#c068fc", "#3f16f9", "#44f270", "#754819", "#f5a683", "#9a2a06",
    "#ebc30e", "#eb1138", "#3f4c08", "#fe8f06", "#eb04dc",
]
def adjust_lightness(color, amount=0.5):
    """Return ``color`` with its HLS lightness multiplied by ``amount``.

    Parameters
    ----------
    color : str or tuple
        A matplotlib named color (looked up in ``matplotlib.colors.cnames``),
        or anything else accepted by ``matplotlib.colors.to_rgb``.
    amount : float
        Factor applied to the lightness; < 1 darkens, > 1 lightens. The
        result is clipped to [0, 1].

    Returns
    -------
    str
        The adjusted color as a hex string (``matplotlib.colors.to_hex``).
    """
    import matplotlib.colors as mc
    import colorsys
    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # not a named color (or unhashable, e.g. a list): use it as given
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    lightness = max(0, min(1, amount * lightness))
    return mc.to_hex(colorsys.hls_to_rgb(hue, lightness, saturation))

# Extend the palette beyond the base colors with shaded variants: the base
# list is followed by all base colors at lightness x0.5, then x1.2, x0.25, x1.1.
# NOTE(review): re-running this cell extends color_list again.
_base_colors = list(color_list)
color_list += [adjust_lightness(c, amount=amt)
               for amt in (0.5, 1.2, 0.25, 1.1)
               for c in _base_colors]


num_clusts = int(np.nanmax(clust_mat[level]))

# One color per module id; cycle through the palette when there are more
# modules than colors instead of silently truncating the colormap.
cmap = ListedColormap([color_list[k % len(color_list)] for k in range(num_clusts)])

# order nodes (rows) by their module in the first layer; NaN rows sort last
sort_idx = np.argsort(clust_mat[level][:, 0])


plt.figure(figsize=(8, 8))
# Pin the color normalisation so module id k (assumed to run 1..num_clusts,
# as using nanmax as the cluster count implies) always maps to the k-th color,
# regardless of which ids happen to be present in the matrix.
plt.imshow(clust_mat[level][sort_idx, :], aspect='auto',
           cmap=cmap,
           vmin=0.5, vmax=num_clusts + 0.5,
           )
plt.title(f'relax rate {rr}, level {level}, num_clust {num_clusts}')
plt.xlabel('layer')
plt.ylabel('nodes')
print(num_clusts)
#%%
plt.savefig(os.path.join(figdir, f'multilayer_relax_rate{rr}_level{level}.png'), dpi=600)
plt.savefig(os.path.join(figdir, f'multilayer_relax_rate{rr}_level{level}.pdf'), dpi=600)



#%% number of distinct modules per layer, for each level
plt.style.use('alex_paper')
plt.rc('font', size=12)
plt.figure()
for level in range(1, 6):
    mat = clust_mat[level]
    # distinct module ids in each column (layer), ignoring NaN (absent nodes)
    num_clust_per_layer = [np.unique(col[~np.isnan(col)]).size for col in mat.T]
    print(num_clust_per_layer)

    plt.plot(np.arange(len(num_clust_per_layer)), num_clust_per_layer, '-o',
             label=f'level {level}')

plt.xlabel('layer')
plt.ylabel('num clust. per layer')

plt.legend(fontsize=12)
#%%

plt.savefig(os.path.join(figdir, f'multilayer_relax_rate{rr}_num_clust_per_layer.pdf'))
